
# Copyright (c) 2018-2023, NVIDIA Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F


class RandomNetworkAdversary(nn.Module):
    """Random Network Adversary (RNA): adds a random perturbation to the
    action generated by the policy.

    The output is binned into `softmax_bins` bins per action channel and a
    softmax is taken over these bins; the caller picks the bin with the
    highest probability to obtain the most likely joint angle.

    Note: OpenAI et al. 2019 found that with a continuous output space and a
          tanh non-linearity, actions would always be close to 0.
          See Section B.3 of https://arxiv.org/abs/1910.07113

    Q: Why do we need dropouts here?

    A: With a CPU-based simulator as in OpenAI et al. 2019, one would use a
       different RNA network per CPU worker. That is not feasible on a
       GPU-based simulator, as creating `num_envs` RNA networks would
       overwhelm GPU memory. Per-env dropout masks approximate this by
       effectively re-sampling the weights of a single shared network for
       each env on the GPU.
    """

    def __init__(self, num_envs, in_dims, out_dims, softmax_bins, device):
        """
        Args:
            num_envs: number of parallel environments (one dropout mask each).
            in_dims: size of the input feature vector.
            out_dims: number of output action channels.
            softmax_bins: number of discretisation bins per action channel.
            device: torch device the layers and masks are placed on.
        """
        super().__init__()

        self.in_dims = in_dims
        self.out_dims = out_dims
        self.softmax_bins = softmax_bins
        self.num_envs = num_envs
        self.device = device

        # Hidden layer widths of the RNA network.
        self.num_feats1 = 512
        self.num_feats2 = 1024

        # First layer.
        self.fc1 = nn.Linear(in_dims, self.num_feats1).to(self.device)
        self.fc1_1 = nn.Linear(self.num_feats1, self.num_feats1).to(self.device)

        # Second layer.
        self.fc2 = nn.Linear(self.num_feats1, self.num_feats2).to(self.device)
        self.fc2_1 = nn.Linear(self.num_feats2, self.num_feats2).to(self.device)

        # Last layer: out_dims channels, each discretised into softmax_bins.
        self.fc3 = nn.Linear(self.num_feats2, out_dims * softmax_bins).to(self.device)

        # Initialises weights AND samples the dropout masks (via
        # refresh_dropout_masks), so the masks are created exactly once here
        # rather than being built inline and immediately re-sampled.
        self._refresh()

    def _refresh(self):
        """Re-initialise weights, switch to eval mode and resample masks."""
        self._init_weights()
        self.eval()
        self.refresh_dropout_masks()

    def _init_weights(self):
        """Kaiming-uniform re-initialisation of every linear layer."""
        print('initialising weights for random network')

        nn.init.kaiming_uniform_(self.fc1.weight)
        nn.init.kaiming_uniform_(self.fc1_1.weight)
        nn.init.kaiming_uniform_(self.fc2.weight)
        nn.init.kaiming_uniform_(self.fc2_1.weight)
        nn.init.kaiming_uniform_(self.fc3.weight)

        return

    def _sample_mask(self, num_feats, keep_prob):
        """Return a (num_envs, num_feats) binary mask.

        Each element is 1 with probability `keep_prob` — a per-env dropout
        mask sampled with the documented `torch.bernoulli(input)` form
        instead of the undocumented float-`p` overload.
        """
        probs = torch.full((self.num_envs, num_feats), float(keep_prob))
        return torch.bernoulli(probs).to(self.device)

    def refresh_dropout_masks(self):
        """Resample both dropout masks with freshly drawn keep-probabilities.

        The keep-probability itself is random per call, so the effective
        dropout rate varies every time the masks are refreshed.
        """
        keep_probs = torch.rand((2, ))

        self.dropout_masks1 = self._sample_mask(self.num_feats1, keep_probs[0])
        self.dropout_masks2 = self._sample_mask(self.num_feats2, keep_probs[1])

        return

    def forward(self, x):
        """Map per-env inputs to a softmax over discretised action bins.

        Args:
            x: tensor of shape (num_envs, in_dims).

        Returns:
            Tensor of shape (num_envs, out_dims, softmax_bins) containing a
            probability distribution over the bins for each action channel.
            The caller selects the bin with the highest probability to obtain
            the perturbed joint angle; no argmax is taken here.
        """
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc1_1(x)
        x = self.dropout_masks1 * x

        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc2_1(x)
        x = self.dropout_masks2 * x

        x = self.fc3(x)

        # Reshape the flat output into (envs, channels, bins) before the
        # softmax over the bin dimension.
        x = x.view(-1, self.out_dims, self.softmax_bins)
        output = F.softmax(x, dim=-1)

        return output


if __name__ == "__main__":
    # Smoke test: push a batch of random observations through the RNA network.
    num_envs = 1024

    # Fall back to CPU so the script also runs on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    RNA = RandomNetworkAdversary(num_envs=num_envs, in_dims=16, out_dims=16,
                                 softmax_bins=32, device=device)

    # torch.randn already returns a tensor; wrapping it in torch.tensor()
    # is the deprecated copy-construct pattern and emits a warning.
    x = torch.randn(num_envs, 16, device=RNA.device)
    y = RNA(x)

    # Report the result instead of dropping into an interactive debugger.
    print('RNA output shape:', tuple(y.shape))

    

